import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Return mapping that caps the deviatoric principal log strain of F.

    The single trainable parameter ``yield_stress`` is the radius of the ball
    that the deviatoric part of the principal logarithmic strain is projected
    back onto.
    """

    def __init__(self, yield_stress: float = 0.07):
        """
        Register the trainable yield stress.

        Args:
            yield_stress (float): initial radius of the admissible ball for the
                deviatoric log strain. In ``forward`` it is floored at 0.05.
        """
        super().__init__()
        self.yield_stress = nn.Parameter(torch.tensor(yield_stress))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Project the deviatoric principal log strain of F onto the yield ball.

        F is factored as ``U @ diag(sigma) @ Vh``. With ``eps = log(sigma)``
        (sigma floored at 1e-6), eps is split into its mean (volumetric part)
        and the deviatoric remainder. If the deviatoric norm exceeds the yield
        stress (floored at 0.05), the deviatoric part is rescaled so its norm
        equals the yield stress. Otherwise the singular values pass through
        unchanged, up to the 1e-6 floor. The volumetric part is never modified.

        NOTE(review): per the torch.linalg.svd documentation, gradients through
        U / Vh can be non-finite or unstable when F has (near-)repeated singular
        values, e.g. F close to identity. Confirm this is acceptable for training.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3).

        Returns:
            torch.Tensor: plastically corrected deformation gradient (B, 3, 3).
        """
        U, S, Vh = torch.linalg.svd(F)

        # Principal (Hencky) log strains; the floor keeps log() finite.
        log_s = torch.log(torch.clamp_min(S, 1e-6))              # (B, 3)

        # Volumetric / deviatoric split of the principal strains.
        vol = log_s.mean(dim=1, keepdim=True)                     # (B, 1)
        dev = log_s - vol                                         # (B, 3)
        dev_len = torch.linalg.norm(dev, dim=1, keepdim=True)     # (B, 1)

        # Values below 0.05 are treated as 0.05 (and get zero gradient there).
        tau_y = torch.clamp_min(self.yield_stress, 0.05)

        # How far the deviatoric strain lies outside the yield ball (>= 0).
        excess = torch.clamp_min(dev_len - tau_y, 0.0)            # (B, 1)

        # Radial shrink factor; the 1e-12 floor only guards the zero-strain case,
        # where excess is 0 and the factor is 1.
        shrink = torch.clamp_min(1.0 - excess / torch.clamp_min(dev_len, 1e-12), 0.0)

        new_sigma = torch.exp(vol + dev * shrink)                 # (B, 3)
        return U @ torch.diag_embed(new_sigma) @ Vh


class ElasticityModel(nn.Module):
    """Fixed-corotated elasticity returning the Kirchhoff stress.

    tau = 2 * mu * (F - R) @ F^T + lam * J * (J - 1) * I,
    with R = U @ Vh from the SVD of F and J = det(F).
    """

    def __init__(self, youngs_modulus_log: float = 9.55, poissons_ratio_sigmoid: float = 2.50):
        """
        Define trainable Young's modulus and Poisson's ratio with physically realistic bounds.

        Args:
            youngs_modulus_log (float): logarithm of Young's modulus (E = exp of this).
            poissons_ratio_sigmoid (float): raw parameter; Poisson's ratio is
                0.45 * sigmoid(raw), so it always lies in (0, 0.45).
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))
        self.poissons_ratio_sigmoid = nn.Parameter(torch.tensor(poissons_ratio_sigmoid))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute the Kirchhoff stress of F under fixed-corotated elasticity.

        The result equals P @ F^T for the first Piola-Kirchhoff stress
        P = 2*mu*(F - R) + lam*J*(J - 1)*F^{-T}. It is computed directly so no
        inverse of F is needed, and it is symmetric, as a Kirchhoff stress must be.

        NOTE(review): if det(F) < 0, R = U @ Vh is a reflection rather than a
        proper rotation, and J is floored at 1e-8. Inverted elements get no
        special treatment. Also, per the torch.linalg.svd docs, gradients through
        U / Vh can be unstable when F has (near-)repeated singular values.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3).

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor (B, 3, 3).
        """
        # Recover material parameters.
        E = self.youngs_modulus_log.exp()
        # Capping nu below 0.45 keeps (1 - 2*nu) >= 0.1, so lam stays finite.
        nu = self.poissons_ratio_sigmoid.sigmoid() * 0.45

        # Lame parameters.
        mu = E / (2.0 * (1.0 + nu))
        lam = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))

        # Let per-sample parameters of shape (B,) broadcast over (B, 3, 3).
        if mu.dim() > 0:
            mu = mu.view(-1, 1, 1)
        if lam.dim() > 0:
            lam = lam.view(-1, 1, 1)

        # Rotational part of the polar decomposition via SVD.
        U, _, Vh = torch.linalg.svd(F)
        R = U @ Vh                                                   # (B, 3, 3)

        # Volume ratio, floored to keep the volumetric term bounded.
        J = torch.clamp_min(torch.linalg.det(F), 1e-8).view(-1, 1, 1)  # (B, 1, 1)

        I = torch.eye(3, dtype=F.dtype, device=F.device).unsqueeze(0)  # (1, 3, 3)
        Ft = F.transpose(1, 2)                                       # (B, 3, 3)

        # Deviatoric (corotated) contribution: 2*mu*(F - R) F^T.
        corotated_stress = 2.0 * mu * ((F - R) @ Ft)                 # (B, 3, 3)

        # Volumetric contribution, already in Kirchhoff form. It must NOT be
        # multiplied by F^T again, which would give lam*J*(J-1)*F^T.
        volumetric_stress = lam * J * (J - 1.0) * I                  # (B, 3, 3)

        kirchhoff_stress = corotated_stress + volumetric_stress
        return kirchhoff_stress
